CSE 5012 Bioinformatics - Term project¶
M. Ali Osman Atik - 202351075002¶
Article : A network integration approach for drug-target interaction prediction and computational drug repositioning from heterogeneous information¶
PDF : https://www.nature.com/articles/s41467-017-00680-8¶
Github : https://github.com/luoyunan/DTINet?tab=readme-ov-file¶
Reading the data files¶
In [193]:
import pandas as pd
import numpy as np
# Data files
# Note: drugs, proteins, diseases and side-effects are organized in the same order across all files, including name lists, ID mappings and interaction/association matrices.
drug = 'DTINet/data/drug.txt' # list of drug names
protein = 'DTINet/data/protein.txt' # list of protein names
disease = 'DTINet/data/disease.txt' # list of disease names
sideEffect ='DTINet/data/se.txt' # list of side effect names
drugDict = 'DTINet/data/drug_dict_map.txt' # a complete ID mapping between drug names and DrugBank ID
proteinDict = 'DTINet/data/protein_dict_map.txt' # a complete ID mapping between protein names and UniProt ID
mat_drug_drug = 'DTINet/data/mat_drug_drug.txt' # Drug-Drug interaction matrix
mat_drug_protein = 'DTINet/data/mat_drug_protein.txt' # Drug_Protein interaction matrix
mat_drug_protein_rh40 = 'DTINet/data/mat_drug_protein_remove_homo.txt' # Drug_Protein interaction matrix, in which homologous proteins with identity score >40% were excluded
mat_drug_disease = 'DTINet/data/mat_drug_disease.txt' # Drug-Disease association matrix
mat_drug_sideEffect = 'DTINet/data/mat_drug_se.txt' # Drug-SideEffect association matrix - rows are drugs columns are side effects
mat_protein_drug = 'DTINet/data/mat_protein_drug.txt' # Protein-Drug interaction matrix
mat_protein_protein = 'DTINet/data/mat_protein_protein.txt' # Protein-Protein interaction matrix
mat_protein_disease = 'DTINet/data/mat_protein_disease.txt' # Protein-Disease association matrix
simDrugs = 'DTINet/data/Similarity_Matrix_Drugs.txt' # Drug similarity scores based on chemical structures of drugs 0.0 ~ 1.0
simProteins = 'DTINet/data/Similarity_Matrix_Proteins.txt' # Protein similarity scores based on primary sequences of proteins 0.0 ~ 100.0
# pre-trained vector representations for drugs and proteins, which were used to produce the results in the paper
drugVector100 = 'DTINet/feature/drug_vector_d100.txt'
proteinVector400 = 'DTINet/feature/protein_vector_d400.txt'
# Suplemantary Excel file 2
# list of novel drug-target interactions predicted by DTINet,
# which was trained based on all drugs and targets that have at least one known interacting pair.
# Drug ID Drug Name Gene Name Protein ID Protein Name label score
# label=1 means known DTIs in the data
supData2 = 'DTINet/supplementary/Supplementary_Data_2.xlsx'
# Reading name lists ******************************************
ls_drug = []
with open(drug, 'r') as file:
ls_drug = [line.strip() for line in file.readlines()]
ls_protein = []
with open(protein, 'r') as file:
ls_protein = [line.strip() for line in file.readlines()]
ls_disease = []
with open(disease, 'r') as file:
ls_disease = [line.strip() for line in file.readlines()]
ls_sideEffect = []
with open(sideEffect, 'r') as file:
ls_sideEffect = [line.strip() for line in file.readlines()]
# Reading drugDict ********************************************
dict_drug = {}
with open(drugDict, 'r') as f:
drug_dict = {line.strip().split(':')[0]: line.strip().split(':')[1] for line in f}
# Reading proteinDict *****************************************
dict_protein = {}
with open(proteinDict, 'r') as f:
dict_protein = {line.strip().split(':')[0]: line.strip().split(':')[1] for line in f}
# Reading association matrices to numpy arrays ****************
np_mat_drug_drug = np.loadtxt(mat_drug_drug, delimiter=' ', dtype=int)
np_mat_drug_protein = np.loadtxt(mat_drug_protein, delimiter=' ', dtype=int)
np_mat_drug_protein_rh40 = np.loadtxt(mat_drug_protein_rh40, delimiter='\t', dtype=int)
np_mat_drug_disease = np.loadtxt(mat_drug_disease, delimiter=' ', dtype=int)
np_mat_drug_sideEffect = np.loadtxt(mat_drug_sideEffect, delimiter=' ', dtype=int)
np_mat_protein_drug = np.loadtxt(mat_protein_drug, delimiter=' ', dtype=int)
np_mat_protein_protein = np.loadtxt(mat_protein_protein, delimiter=' ', dtype=int)
np_mat_protein_disease = np.loadtxt(mat_protein_disease, delimiter=' ', dtype=int)
# Reading similarity matrices *********************************
np_mat_simDrugs = np.loadtxt(simDrugs)
np_mat_simProteins = np.loadtxt(simProteins, delimiter=' ')
# Reading pretrained features
drugVec100 = np.loadtxt(drugVector100)
proteinVec400 = np.loadtxt(proteinVector400)
# Reading suplemantary Excel file 2 ***************************
df_sd2 = pd.read_excel(supData2, usecols=[0, 1, 2, 3, 4, 5, 6], engine='openpyxl')
print("All files are read")
All files are read
Data samples¶
In [2]:
ls_drug[:5]
Out[2]:
['DB00050', 'DB00152', 'DB00162', 'DB00175', 'DB00176']
In [3]:
ls_protein[:5]
Out[3]:
['Q9UI32', 'P00488', 'P35228', 'P06737', 'P11766']
In [4]:
ls_disease[:10]
Out[4]:
['depressive disorder', 'drug-induced liver injury', 'mercury poisoning', 'necrosis', 'neoplasms', 'anemia, hemolytic', 'attention deficit and disruptive behavior disorders', 'autistic disorder', 'cognition disorders', 'cystitis']
In [5]:
ls_sideEffect[:5]
Out[5]:
['cerebrovascular accident', 'rash', 'ptosis', 'paresthesia', 'bronchospasm']
In [6]:
for key, value in list(drug_dict.items())[:5]:
print(f"{key}: {value}")
DB00001: Lepirudin DB00002: Cetuximab DB00003: Dornase Alfa DB00004: Denileukin diftitox DB00005: Etanercept
In [7]:
for key, value in list(dict_protein.items())[:5]:
print(f"{key}: {value}")
P45059: ftsI P19113: HDC Q9UI32: GLS2 P00488: F13A1 P35228: NOS2
In [8]:
np_mat_drug_drug[:5, :10]
Out[8]:
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])
In [9]:
np_mat_drug_protein[:5, :10]
Out[9]:
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
In [10]:
np_mat_drug_protein_rh40[:5, :10]
Out[10]:
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
In [11]:
np_mat_drug_disease[:5, :10]
Out[11]:
array([[1, 1, 1, 1, 1, 0, 0, 1, 1, 0],
[1, 1, 0, 0, 1, 0, 0, 1, 1, 0],
[1, 1, 0, 1, 1, 1, 0, 1, 1, 1],
[1, 1, 0, 1, 1, 1, 0, 1, 1, 0],
[1, 1, 0, 1, 0, 0, 1, 1, 0, 0]])
In [12]:
np_mat_drug_sideEffect[:5, :10]
Out[12]:
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 1, 0, 0, 0, 0, 1, 1]])
In [13]:
np_mat_protein_drug[:5, :10]
Out[13]:
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
In [14]:
np_mat_protein_protein[:5, :10]
Out[14]:
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]])
In [15]:
np_mat_protein_disease[:5, :10]
Out[15]:
array([[1, 1, 0, 1, 1, 1, 0, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
In [16]:
np_mat_simDrugs[:5, :10]
Out[16]:
array([[1. , 0.23333333, 0.12745099, 0.14166667, 0.22314049,
0.42857143, 0.55769229, 0.16788322, 0.40625 , 0.26666668],
[0.23333333, 1. , 0.07142857, 0.1007752 , 0.14179105,
0.18243243, 0.16783217, 0.10738255, 0.14754099, 0.11214953],
[0.12745099, 0.07142857, 1. , 0.5 , 0.12162162,
0.14130434, 0.15476191, 0.26666668, 0.16666667, 0.11363637],
[0.14166667, 0.1007752 , 0.5 , 1. , 0.14130434,
0.20952381, 0.29347825, 0.44444445, 0.24324325, 0.07575758],
[0.22314049, 0.14179105, 0.12162162, 0.14130434, 1. ,
0.2 , 0.30000001, 0.20754717, 0.30379745, 0.32258064]])
In [17]:
np_mat_simProteins[:5, :10]
Out[17]:
array([[100. , 11.864407, 15.819209, 14.124294, 10.734463,
10.169492, 10.734463, 15.819209, 12.429379, 11.864407],
[ 11.864407, 100. , 11.338798, 9.972678, 10.427807,
12.295082, 13.496933, 9.42623 , 12.451362, 11.574074],
[ 15.819209, 11.338798, 100. , 10.743802, 10.427807,
10.928962, 11.656442, 9.874327, 12.840467, 12.962963],
[ 14.124294, 9.972678, 10.743802, 100. , 12.032086,
13.114754, 11.96319 , 10.625738, 12.451362, 13.425926],
[ 10.734463, 10.427807, 10.427807, 12.032086, 100. ,
11.202186, 11.042945, 12.299465, 13.229572, 12.037037]])
In [18]:
df_sd2[:5]
Out[18]:
| Drug ID | Drug Name | Gene Name | Protein ID | Protein Name | label | score | |
|---|---|---|---|---|---|---|---|
| 0 | DB06274 | Alvimopan | OPRD1 | P41143 | Delta-type opioid receptor | 1 | 1.6778 |
| 1 | DB06274 | Alvimopan | OPRK1 | P41145 | Kappa-type opioid receptor | 1 | 1.3948 |
| 2 | DB06274 | Alvimopan | OPRM1 | P35372 | Mu-type opioid receptor | 1 | 1.3287 |
| 3 | DB00246 | Ziprasidone | HTR1D | P28221 | 5-hydroxytryptamine receptor 1D | 1 | 1.0877 |
| 4 | DB00246 | Ziprasidone | HTR1B | P28222 | 5-hydroxytryptamine receptor 1B | 1 | 1.0663 |
Protein similarities¶
In [113]:
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
# Similarity threshold
similarity_threshold = 0.9
# Create an undirected graph
G = nx.Graph()
# Add nodes (proteins) to the graph
G.add_nodes_from(ls_protein)
# Add edges for pairs of proteins with similarity greater than the threshold
for i in range(np_mat_simProteins.shape[0]):
for j in range(i + 1, np_mat_simProteins.shape[1]): # Only consider the upper triangle
if np_mat_simProteins[i, j] > similarity_threshold*100:
G.add_edge(ls_protein[i], ls_protein[j], weight=np_mat_simProteins[i, j])
# Remove isolated nodes from the graph
G.remove_nodes_from(list(nx.isolates(G)))
# Draw the filtered graph
plt.figure(figsize=(20, 10))
# Draw the graph
pos = nx.spring_layout(G, k=1, iterations=10) # Adjust k for node spacing
# Draw nodes and labels
nx.draw_networkx_nodes(G, pos, node_size=700, node_color='pink')
nx.draw_networkx_labels(G, pos, font_size=6, font_family='sans-serif')
nx.draw_networkx_edges(G, pos, width=1, edge_color='g')
plt.title("Protein Similarity Graph")
plt.savefig('DTINet/plots/protein_similarities.png', format='png', dpi=300)
plt.show()
Drug similarities¶
In [110]:
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
# Similarity threshold
similarity_threshold = 0.9
# Create an undirected graph
G = nx.Graph()
# Add nodes (proteins) to the graph
G.add_nodes_from(ls_drug)
# Add edges for pairs of proteins with similarity greater than the threshold
for i in range(np_mat_simDrugs.shape[0]):
for j in range(i + 1, np_mat_simDrugs.shape[1]): # Only consider the upper triangle
if np_mat_simDrugs[i, j] >= similarity_threshold:
G.add_edge(ls_drug[i], ls_drug[j], weight=np_mat_simDrugs[i, j])
# Remove isolated nodes from the graph
G.remove_nodes_from(list(nx.isolates(G)))
# Draw the filtered graph
plt.figure(figsize=(20, 10))
# Draw the graph
pos = nx.spring_layout(G, k=1, iterations=10) # Adjust k for node spacing
# Draw nodes and labels
nx.draw_networkx_nodes(G, pos, node_size=700, node_color='lightgreen')
nx.draw_networkx_labels(G, pos, font_size=6, font_family='sans-serif')
nx.draw_networkx_edges(G, pos, width=1, edge_color='g')
plt.title("Drug Similarity Graph")
plt.savefig('DTINet/plots/drug_similarities.png', format='png', dpi=300)
plt.show()
Drug - protein interactions¶
In [115]:
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
sub_sd2 = df_sd2[df_sd2["score"] > 0.7]
# Create a graph
G = nx.Graph()
# Add edges with weights to the graph
for index, row in sub_sd2.iterrows():
G.add_edge(row["Drug ID"], row["Protein ID"], weight=row["score"])
# Separate nodes into Drug IDs and Protein IDs
drug_nodes = set(sub_sd2["Drug ID"])
protein_nodes = set(sub_sd2["Protein ID"])
# Assign colors based on node type
node_colors = []
for node in G.nodes():
if node in drug_nodes:
node_colors.append('lightgreen')
elif node in protein_nodes:
node_colors.append('pink')
# Draw the graph
plt.figure(figsize=(20, 10))
pos = nx.spring_layout(G, k=10, iterations=10) # Adjust k for node spacing
# Draw the nodes with assigned colors
nx.draw_networkx_nodes(G, pos, node_size=700, node_color=node_colors)
# Draw the edges with weights as the width
edges = G.edges(data=True)
widths = [d['weight'] * 1 for (u, v, d) in edges] # Adjust the multiplier for better visualization
nx.draw_networkx_edges(G, pos, edgelist=edges, width=widths, edge_color='g')
# Draw node labels
nx.draw_networkx_labels(G, pos, font_size=6, font_family="sans-serif")
# Display the graph
plt.title("Drug-Protein Interaction Graph")
plt.savefig('DTINet/plots/drug_protein_interaction.png', format='png', dpi=300)
plt.show()
In [21]:
import torch
from torch.optim import LBFGS
import numpy as np
import time
# Check if CUDA is available and select the appropriate device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using : ", device)
# Convert input data to PyTorch tensors and move them to the selected device
X = torch.tensor(drugVec100, dtype=torch.float32).to(device)
Y = torch.tensor(proteinVec400, dtype=torch.float32).to(device)
P = torch.tensor(np_mat_drug_protein, dtype=torch.float32).to(device)
# Define the cost function for optimization
def cost_function(Z_vec):
Z = Z_vec.view(100, 400)
reconstruction = torch.matmul(torch.matmul(X, Z), Y.transpose(0, 1))
return torch.norm(P - reconstruction, 'fro')**2
# Initial guess for Z, moved to the selected device
Z_initial = torch.rand(100, 400, dtype=torch.float32, requires_grad=True).to(device)
# Measure runtime
start_time = time.time()
# Detach the tensor from the computation graph to avoid the "non-leaf Tensor" error
Z_initial = Z_initial.detach().requires_grad_(True)
# Minimize the cost function using LBFGS optimizer from PyTorch
optimizer = LBFGS([Z_initial], lr=1)
# Define the closure function for LBFGS optimizer
def closure():
optimizer.zero_grad()
loss = cost_function(Z_initial.flatten())
loss.backward()
return loss
optimizer.step(closure)
end_time = time.time()
runtime = end_time - start_time
# Retrieve the optimized Z
Z_optimized = Z_initial.detach().cpu().numpy().reshape(100, 400)
# Save the computed Z matrix to a file
output_path = 'DTINet/data/optimized_Z.npy'
np.save(output_path, Z_optimized)
print(f"Runtime: {runtime} seconds")
print(f"Z matrix saved to {output_path}")
print("Computed Z:")
print(Z_optimized)
Using : cpu Runtime: 1.3358736038208008 seconds Z matrix saved to DTINet/data/optimized_Z.npy Computed Z: [[-0.06226376 0.03026691 -0.00944128 ... 0.11185674 0.17309016 0.01962686] [ 0.18468113 0.06588919 -0.06278056 ... -0.05458539 -0.13401051 -0.25046456] [-0.01374707 0.15950859 -0.01560685 ... -0.26745644 -0.07330911 -0.18190956] ... [-0.18878788 -0.01065327 -0.17394738 ... 0.02563913 0.07577506 0.11255871] [-0.19921055 -0.11214258 -0.13700141 ... 0.533817 0.6438616 0.6731604 ] [-0.16666351 -0.10173433 -0.1739214 ... 0.36676386 0.5010344 0.58837444]]
In [163]:
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score
# Calculate the predicted matrix P_predicted
P_predicted = np.matmul(np.matmul(X.cpu().numpy(), Z_optimized), Y.cpu().numpy().T)
# Apply a threshold to binarize the predicted matrix
threshold = 0.9
P_predicted_binary = (P_predicted > threshold).astype(int)
# Define evaluation metrics functions
def mean_squared_error(y_true, y_pred):
return np.mean((y_true - y_pred)**2)
def mean_absolute_error(y_true, y_pred):
return np.mean(np.abs(y_true - y_pred))
# Define evaluation metrics functions for binary matrices
def accuracy(y_true, y_pred):
return np.mean(y_true == y_pred)
# Flatten matrices for metric calculation
P_flat = P.cpu().numpy().flatten()
P_predicted_binary_flat = P_predicted_binary.flatten()
P_predicted_flat = P_predicted.flatten()
# Calculate evaluation metrics
mse = mean_squared_error(P_flat, P_predicted_binary_flat)
mae = mean_absolute_error(P_flat, P_predicted_binary_flat)
accuracy_score = accuracy(P_flat, P_predicted_binary_flat)
# Calculate AUROC and AUPR
auroc = roc_auc_score(P_flat, P_predicted_flat)
aupr = average_precision_score(P_flat, P_predicted_flat)
print("Evaluation Metrics:")
print(f"Mean Squared Error (MSE): {mse}")
print(f"Mean Absolute Error (MAE): {mae}")
print(f"Accuracy Score: {accuracy_score}")
print(f"AUROC: {auroc}")
print(f"AUPR: {aupr}")
# Plot the evaluation metrics
import matplotlib.pyplot as plt
metrics = ['MSE', 'MAE', 'Accuracy', 'AUROC', 'AUPR']
values = [mse, mae, accuracy_score, auroc, aupr]
plt.figure(figsize=(12, 6))
plt.bar(metrics, values, color=['blue', 'green', 'red', 'purple', 'orange'])
plt.xlabel('Metrics')
plt.ylabel('Scores')
plt.title('DTINet Model Evaluation Metrics')
plt.ylim(0, 1) # All metrics are within the range 0 to 1 for binary classification
for i, v in enumerate(values):
plt.text(i, v + 0.02, f"{v:.2f}", ha='center', va='bottom')
plt.show()
Evaluation Metrics: Mean Squared Error (MSE): 0.1288617612770155 Mean Absolute Error (MAE): 0.1288617612770155 Accuracy Score: 0.8711382387229845 AUROC: 0.5978626522806747 AUPR: 0.0025473870075329967
Link prediction with pretrained feature matrices - RFC¶
In [25]:
# Create Graph
G = nx.Graph()
# Add drug nodes
for i, drug in enumerate(ls_drug):
G.add_node(drug, bipartite=0)
# Add protein nodes
for i, protein in enumerate(ls_protein):
G.add_node(protein, bipartite=1)
# Add edges based on the interaction matrix
for i, drug in enumerate(ls_drug):
for j, protein in enumerate(ls_protein):
if np_mat_drug_protein[i, j] == 1:
G.add_edge(drug, protein)
In [355]:
# Feature Preparation
from sklearn.preprocessing import StandardScaler
# Combine drug and protein features into a feature matrix
drug_features = drugVec100
protein_features = proteinVec400
scaler = StandardScaler()
drug_features = scaler.fit_transform(drug_features)
protein_features = scaler.fit_transform(protein_features)
# Create a feature matrix for each drug-protein pair
def create_feature_matrix(drug_features, protein_features):
features = []
labels = []
pairs = []
for i, drug in enumerate(ls_drug):
for j, protein in enumerate(ls_protein):
features.append(np.concatenate((drug_features[i], protein_features[j])))
labels.append(np_mat_drug_protein[i, j])
pairs.append((drug, protein))
return np.array(features), np.array(labels), pairs
features, labels, pairs = create_feature_matrix(drug_features, protein_features)
In [28]:
# Train-Test Split
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test, pairs_train, pairs_test = train_test_split(features, labels, pairs, test_size=0.2, random_state=42)
In [31]:
# Training
from tqdm import tqdm
from sklearn.ensemble import RandomForestClassifier
# Function to update tqdm bar
class tqdmBatchProgressBar:
def __init__(self, total):
self.pbar = tqdm(total=total)
def __call__(self, clf, X, y, sample_weight=None):
self.pbar.update()
# Initialize tqdm progress bar
progress_bar = tqdmBatchProgressBar(total=clf.n_estimators)
# Initialize and train RandomForestClassifier with progress bar
clf = RandomForestClassifier(n_estimators=100, random_state=42, warm_start=True, n_jobs=-1)
for i in range(clf.n_estimators):
clf.n_estimators = i + 1
clf.fit(X_train, y_train)
progress_bar(clf, X_train, y_train)
progress_bar.pbar.close()
100%|██████████████████████████████████████████████████████████████████████████████| 100/100 [1:10:45<00:00, 42.45s/it]
In [32]:
# Save the model
import joblib
model_path = 'DTINet/data/predictionModel_RFC.pkl'
joblib.dump(clf, model_path)
print(f"Model saved to {model_path}")
Model saved to DTINet/data/predictionModel_RFC.pkl
In [194]:
# Load the model
import joblib
model_path = 'DTINet/data/predictionModel_RFC.pkl'
clf = joblib.load(model_path)
print("Model loaded successfully")
Model loaded successfully
In [195]:
# Evaluation
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, recall_score, f1_score
# Calculate predicted probabilities and binary predictions
y_pred_proba = clf.predict_proba(X_test)[:, 1]
y_pred = clf.predict(X_test)
# Compute evaluation metrics
roc_auc = roc_auc_score(y_test, y_pred_proba)
avg_precision = average_precision_score(y_test, y_pred_proba)
accuracy = accuracy_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
# Print the evaluation metrics
print(f"ROC-AUC: {roc_auc}")
print(f"Average Precision: {avg_precision}")
print(f"Accuracy: {accuracy}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1}")
# Plot the evaluation metrics
metrics = ['ROC-AUC', 'Average Precision', 'Accuracy', 'Recall', 'F1 Score']
values = [roc_auc, avg_precision, accuracy, recall, f1]
plt.figure(figsize=(10, 6))
plt.bar(metrics, values, color=['blue', 'green', 'red', 'purple', 'orange'])
plt.xlabel('Metrics')
plt.ylabel('Scores')
plt.title('RFC Model Evaluation')
plt.ylim(0, 1) # All metrics are within the range 0 to 1
for i, v in enumerate(values):
plt.text(i, v - 0.05, f"{v:.3f}", ha='center', va='bottom')
plt.savefig('DTINet/plots/_model_evaluation.png', format='png', dpi=300)
plt.show()
ROC-AUC: 0.4999952476903775 Average Precision: 0.003834609958624085 Accuracy: 0.9961653900413759 Recall: 0.0 F1 Score: 0.0
In [196]:
# Predict Potential Links
y_pred = clf.predict(X_test)
predicted_links = [(pairs_test[i][0], pairs_test[i][1]) for i in range(len(y_pred)) if y_pred[i] == 1 and y_test[i] == 0]
print("Predicted new links:")
for link in predicted_links:
print(link)
Predicted new links:
In [119]:
import matplotlib.pyplot as plt
# Create a new graph for predicted links
G_predicted = nx.Graph()
# Add predicted nodes and edges
for u, v in predicted_links:
G_predicted.add_node(u, bipartite=0)
G_predicted.add_node(v, bipartite=1)
G_predicted.add_edge(u, v)
# All edges in data
# all_known_edges = [(ls_drug[i], ls_protein[j]) for i, j in zip(*np.where(np_mat_drug_protein == 1))]
# Add known edges between the predicted nodes
known_edges = [(u, v) for u, v in predicted_links if np_mat_drug_protein[ls_drug.index(u)][ls_protein.index(v)] == 1]
# Draw the predicted links graph
plt.figure(figsize=(20, 10))
pos = nx.spring_layout(G_predicted, k=1, iterations=20) # Adjust k for node spacing
# Draw known edges in green
nx.draw_networkx_edges(G_predicted, pos, edgelist=known_edges, edge_color='green', width=1)
# Draw predicted links in red
predicted_edges = [(u, v) for u, v in predicted_links if np_mat_drug_protein[ls_drug.index(u)][ls_protein.index(v)] == 0]
nx.draw_networkx_edges(G_predicted, pos, edgelist=predicted_edges, edge_color='red', width=1)
# Draw nodes with different colors for drugs and proteins
drug_nodes = [node for node in G_predicted.nodes if node in ls_drug]
protein_nodes = [node for node in G_predicted.nodes if node in ls_protein]
nx.draw_networkx_nodes(G_predicted, pos, nodelist=drug_nodes, node_color='lightgreen', node_size=700, label='Drugs')
nx.draw_networkx_nodes(G_predicted, pos, nodelist=protein_nodes, node_color='pink', node_size=700, label='Proteins')
# Draw labels
nx.draw_networkx_labels(G_predicted, pos, font_size=6, font_family='sans-serif')
plt.title("Predicted Drug-Protein Interaction Network")
plt.legend(markerscale=0.35)
plt.savefig('DTINet/plots/predicted_drug_protein_interaction.png', format='png', dpi=300)
plt.show()
In [153]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
# Create a new graph for predicted links
G_predicted = nx.Graph()
# Add predicted nodes and edges
for u, v in predicted_links:
G_predicted.add_node(u, bipartite=0)
G_predicted.add_node(v, bipartite=1)
G_predicted.add_edge(u, v)
# All edges in data
all_known_edges = [(ls_drug[i], ls_protein[j]) for i, j in zip(*np.where(np_mat_drug_protein == 1))]
# Add known edges between the predicted nodes
known_edges = [(u, v) for u, v in predicted_links if np_mat_drug_protein[ls_drug.index(u)][ls_protein.index(v)] == 1]
# Identify nodes involved in predicted edges
predicted_nodes = set([u for u, v in predicted_links] + [v for u, v in predicted_links])
# Identify all known edges connected to predicted nodes but not in predicted edges
additional_edges = [
(u, v) for u, v in all_known_edges
if (u in predicted_nodes or v in predicted_nodes) and (u, v) not in predicted_edges and (u, v) not in known_edges
]
# Ensure all nodes in additional_edges are added to the graph
for u, v in additional_edges:
if u not in G_predicted:
G_predicted.add_node(u, bipartite=0 if u in ls_drug else 1)
if v not in G_predicted:
G_predicted.add_node(v, bipartite=0 if v in ls_drug else 1)
# Draw the predicted links graph
plt.figure(figsize=(20, 10))
pos = nx.spring_layout(G_predicted, k=0.1, iterations=5) # Adjust k for node spacing
# Draw known edges in green
nx.draw_networkx_edges(G_predicted, pos, edgelist=known_edges, edge_color='green', width=2)
# Draw predicted links in red
predicted_edges = [(u, v) for u, v in predicted_links if np_mat_drug_protein[ls_drug.index(u)][ls_protein.index(v)] == 0]
nx.draw_networkx_edges(G_predicted, pos, edgelist=predicted_edges, edge_color='red', width=2)
# Draw additional known edges in black
nx.draw_networkx_edges(G_predicted, pos, edgelist=additional_edges, edge_color='gray', width=0.2)
# Draw nodes with different colors for drugs and proteins
drug_nodes = [node for node in G_predicted.nodes if node in ls_drug]
protein_nodes = [node for node in G_predicted.nodes if node in ls_protein]
nx.draw_networkx_nodes(G_predicted, pos, nodelist=drug_nodes, node_color='lightgreen', node_size=700, label='Drugs')
nx.draw_networkx_nodes(G_predicted, pos, nodelist=protein_nodes, node_color='pink', node_size=700, label='Proteins')
# Draw labels
nx.draw_networkx_labels(G_predicted, pos, font_size=6, font_family='sans-serif')
plt.title("Predicted Drug-Protein Interaction Network")
plt.legend(markerscale=0.35)
plt.savefig('DTINet/plots/predicted_drug_protein_interaction_network.png', format='png', dpi=300)
plt.show()
Link prediction with pretrained feature matrices - SGDC¶
In [201]:
import networkx as nx
# Create bipartite graph
G = nx.Graph()
for i, drug in enumerate(ls_drug):
G.add_node(drug, bipartite=0)
for i, protein in enumerate(ls_protein):
G.add_node(protein, bipartite=1)
for i, drug in enumerate(ls_drug):
for j, protein in enumerate(ls_protein):
if np_mat_drug_protein[i, j] == 1:
G.add_edge(drug, protein)
In [202]:
# Prepare data for link prediction
drug_features = {ls_drug[i]: drugVec100[i] for i in range(len(ls_drug))}
protein_features = {ls_protein[i]: proteinVec400[i] for i in range(len(ls_protein))}
edges = list(G.edges())
non_edges = list(nx.non_edges(G))
X = []
y = []
missing_drugs = set()
missing_proteins = set()
for edge in edges:
drug, protein = edge
if drug not in drug_features:
missing_drugs.add(drug)
continue
if protein not in protein_features:
missing_proteins.add(protein)
continue
X.append(np.concatenate((drug_features[drug], protein_features[protein])))
y.append(1)
for edge in non_edges:
drug, protein = edge
if drug not in drug_features:
missing_drugs.add(drug)
continue
if protein not in protein_features:
missing_proteins.add(protein)
continue
X.append(np.concatenate((drug_features[drug], protein_features[protein])))
y.append(0)
X = np.array(X)
y = np.array(y)
if missing_drugs:
print(f"Missing drugs: {missing_drugs}")
if missing_proteins:
print(f"Missing proteins: {missing_proteins}")
Missing drugs: {'P02787', 'P09601', 'P15121', 'O43741', 'Q9GZU7', 'P98160', 'Q05655', 'P49247', 'Q9NWZ3', 'Q9NR96', 'Q9H6Z9', 'P42262', 'P48544', 'Q96GD4', 'Q09013', 'O43837', 'P19835', 'P05108', 'O95180', 'Q92930', 'Q96QU6', 'P09211', 'Q9H244', 'P62913', 'P50416', 'Q7Z406', 'Q14289', 'P13727', 'P49761', 'P29460', 'P07478', 'P55072', 'Q03154', 'P31937', 'P15692', 'Q06187', 'P40937', 'Q9Y5Y6', 'Q02817', 'P61278', 'P09210', 'O00555', 'Q9Y5Y9', 'P01138', 'P31153', 'O75460', 'Q8NFA2', 'Q92830', 'P18031', 'O00204', 'P20618', 'P09488', 'O75874', 'P51606', 'Q9BZR6', 'P35499', 'P13726', 'P01112', 'Q9NY72', 'P04745', 'Q9H169', 'O43570', 'Q15067', 'P21728', 'O14818', 'P12235', 'P05141', 'P35626', 'Q9NRE1', 'O14949', 'P24855', 'Q13258', 'P13945', 'P29275', 'P04070', 'P09848', 'P14735', 'P00505', 'P48664', 'P29120', 'P08588', 'P52799', 'P06756', 'P18283', 'P00367', 'P45983', 'O15239', 'Q99624', 'O75865', 'P16112', 'Q14957', 'Q13285', 'Q5T6L4', 'P46098', 'P07327', 'P29323', 'Q9UBX1', 'Q12797', 'P05451', 'P04632', 'P11586', 'Q9UNI1', 'P54760', 'P25774', 'Q9NYK1', 'P08637', 'P50750', 'P41250', 'Q8WVQ1', 'P23467', 'P08473', 'P31930', 'P06493', 'P17900', 'Q13131', 'P68366', 'Q8TCU5', 'O00459', 'P08235', 'P29074', 'Q15118', 'P34972', 'P10721', 'P35368', 'P12532', 'P07384', 'P02675', 'P10912', 'Q9Y5K3', 'P29466', 'P43235', 'Q12809', 'P22748', 'P04053', 'P41143', 'Q08188', 'P27986', 'O60391', 'P78540', 'Q03431', 'P01137', 'P08574', 'P49411', 'P19784', 'O00187', 'P68400', 'P49069', 'Q04844', 'Q13363', 'P27707', 'P22894', 'P38484', 'P21918', 'P63211', 'P43088', 'P02775', 'P00390', 'P11474', 'P11230', 'P01100', 'P17787', 'P11488', 'Q6N063', 'P36888', 'Q13426', 'P21917', 'Q9BYF1', 'P16070', 'Q9P2W7', 'P51553', 'P02735', 'P30613', 'P30419', 'O96000', 'Q14494', 'P60568', 'P51168', 'P04626', 'P43005', 'Q9UHW9', 'P07451', 'P10826', 'O95665', 'Q01362', 'P60174', 'Q9Y286', 'P11021', 'Q15375', 'P01834', 'Q96FI4', 'P01133', 'O95237', 'P25787', 'Q07837', 'Q9UGN5', 'P11712', 'Q14500', 'P48039', 'Q99519', 'P27695', 'P42768', 'P23381', 'P51172', 'Q09428', 'Q9UHQ9', 'P05546', 'P53674', 'O95622', 'P30519', 'O15460', 'P00750', 'P07949', 'P51692', 'O60493', 'P61160', 'P48051', 'P34913', 'P37173', 'P40616', 'P29474', 'P12318', 'P13500', 'P35346', 'O00469', 'P06280', 'P06576', 'P19838', 'P21589', 'Q92947', 'Q9Y6Y9', 'P56817', 'P24666', 'P43007', 'P99999', 'Q86YB8', 'Q16773', 'P49721', 'O75469', 'Q15822', 'Q8IWL2', 'P47901', 'P51003', 'Q16558', 'P50579', 'P04054', 'Q13002', 'P30305', 'P80404', 'Q9NQX3', 'P30085', 'P61626', 'P33981', 'P49902', 'P41235', 'P13631', 'P01584', 'P48169', 'P25098', 'P21549', 'P61073', 'P30874', 'P10275', 'P01730', 'P67775', 'P32929', 'P28070', 'Q9BZZ2', 'P35754', 'P10827', 'Q8WTV0', 'P50213', 'P03372', 'P14324', 'Q4U2R8', 'P22459', 'P62714', 'Q02218', 'P40925', 'P11388', 'P06744', 'P22455', 'Q03518', 'P00734', 'O15217', 'P06870', 'P23416', 'P04181', 'P45844', 'Q99460', 'P84077', 'P78559', 'P06400', 'O15143', 'P19438', 'P05091', 'P34995', 'O75380', 'P35241', 'Q13200', 'Q9GZT9', 'P33402', 'P56282', 'P39086', 'P62158', 'Q9NWT6', 'P04035', 'Q16644', 'Q07912', 'Q9H2G2', 'P50052', 'P68133', 'P42892', 'P00441', 'P13639', 'Q16613', 'P12314', 'P42345', 'Q9NR19', 'P10415', 'P35270', 'P14868', 'P24385', 'P19429', 'P22830', 'P18509', 'P17050', 'P23743', 'P21980', 'P08183', 'P29034', 'Q13748', 'P51787', 'P16152', 'P22695', 'Q9UI32', 'Q9Y345', 'P11168', 'P42336', 'P34969', 'O95838', 'P25789', 'Q99873', 'P23141', 'P01008', 'Q96I99', 'Q9UQF2', 'Q13526', 'Q92696', 'Q9NR97', 'P05412', 'P14679', 'P07195', 'P68871', 'Q99542', 'Q9H2S1', 'P08908', 'P00480', 'Q8NI22', 'Q9Y6M9', 'P08246', 'Q9H4M7', 'P28476', 'Q9NQS7', 'Q9UKQ2', 'P49448', 'O14939', 'P22891', 'A6NG28', 'P38435', 'P48061', 'P09874', 'Q13332', 'O14649', 'O15530', 'P54284', 'P51659', 'P30273', 'P15104', 'P30968', 'P63316', 'P02794', 'P54710', 'Q15181', 'P04818', 'O15399', 'Q9NTG7', 'P27487', 'P31751', 'O43175', 'P08254', 'Q13509', 'P28482', 'P01374', 'P41146', 'P08684', 'P08100', 'P35237', 'Q6YP21', 'P06132', 'P11684', 'Q16739', 'Q15102', 'O60512', 'P35219', 'O15554', 'P18858', 'P17181', 'P63000', 'P36544', 'P00918', 'P04271', 'Q8WUI4', 'P22234', 'P26639', 'P17752', 'P30838', 'P02746', 'O95342', 'P10253', 'Q14524', 'P02679', 'P43080', 'P02788', 'O60547', 'Q96KC2', 'P08238', 'Q09470', 'Q16831', 'P07225', 'O15169', 'Q92887', 'P02708', 'P12931', 'P34741', 'Q15648', 'P24046', 'P43220', 'P61024', 'Q13224', 'Q9NVH6', 'P31994', 'P22392', 'P31350', 'P55011', 'O14732', 'P03950', 'Q16270', 'P01906', 'P18089', 'P54687', 'P55196', 'Q9H3N8', 'P09619', 'Q02880', 'P51170', 'P28223', 'P15382', 'O60760', 'O14880', 'Q13627', 'P27338', 'Q9H228', 'Q13133', 'P17405', 'P01579', 'Q92569', 'P30872', 'Q13003', 'Q04760', 'P12271', 'Q05940', 'P08514', 'Q9BQE3', 'P11836', 'Q9Y2I1', 'Q9Y285', 'Q9H015', 'Q06520', 'P09038', 'Q9P0J0', 'P04234', 'P12277', 'Q9UBS5', 'P51178', 'P53355', 'P54368', 'Q96HD9', 'P11142', 'P98164', 'O60706', 'Q9Y6F1', 'P16109', 'P00749', 'P47985', 'P38606', 'P15170', 'P15559', 'P10828', 'P36405', 'P23280', 'P49336', 'Q10588', 'P04062', 'Q9NS18', 'Q9H4B7', 'P23677', 'Q9Y234', 'Q8N159', 'Q08881', 'Q12756', 'P35354', 'P07550', 'P00751', 'P41231', 'Q8IUZ5', 'P15291', 'P04179', 'P27361', 'P04746', 'Q9NY46', 'P55789', 'P55210', 'Q99808', 'P22413', 'P49354', 'P42574', 'O95190', 'Q16478', 'P68363', 'Q13362', 'P54278', 'Q99062', 'P33681', 'Q08462', 'Q7LG56', 'Q00534', 'P20823', 'P00746', 'P01023', 'O15270', 'P09622', 'P19634', 'P00742', 'P11473', 'P48745', 'O60896', 'Q12791', 'P14618', 'Q8N142', 'P05067', 'Q9BW91', 'Q969V6', 'Q9UHI5', 'P07711', 'P11926', 'P63098', 'P59998', 'Q07001', 'P56556', 'P20309', 'P48549', 'P41240', 'P29274', 'P09884', 'Q13956', 'Q8N1C3', 'P56181', 'Q13621', 'P42330', 'P19622', 'P54750', 'O43739', 'P06241', 'O43497', 'O60674', 'P19235', 'P41091', 'O75438', 'Q08257', 'P16410', 'P28161', 'Q8IVA8', 'Q96EY8', 'P04049', 'Q15418', 'Q9HB21', 'P00491', 'O75251', 'P26572', 'P35372', 'O76082', 'Q7L0J3', 'Q00975', 'P05019', 'P31939', 'Q6IB77', 'Q96RD7', 'P08912', 'Q99714', 'P00736', 'P09466', 'P19971', 'O15066', 'P49588', 'P20674', 'P26440', 'Q99707', 'P10600', 'P37288', 'P40238', 'P05093', 'P51512', 'P09382', 'O75015', 'P37088', 'P11226', 'P78330', 'P00450', 'P41594', 'P23378', 'Q06124', 'P16083', 'P01270', 'P15529', 'P51681', 'Q14181', 'P22102', 'P49589', 'O95831', 'P08311', 'P01308', 'P29597', 'P41743', 'P21673', 'O43676', 'Q9NPC2', 'Q99835', 'P13612', 'P48637', 'Q9NPB1', 'P05771', 'Q16620', 'P17612', 'P61221', 'P07900', 'P14416', 'O15145', 'Q03426', 'Q08289', 'P49356', 'P42261', 'Q14643', 'P25101', 'O00180', 'Q13564', 'Q16665', 'Q16515', 'P06731', 'P33151', 'P01116', 'P08236', 'Q92753', 'P01178', 'P05120', 'O94768', 'O14983', 'P49763', 'Q16760', 'Q00987', 'P14770', 'P15509', 'P31749', 'Q15842', 'O94925', 'P06732', 'P62745', 'P37023', 'P46459', 'P00813', 'P47872', 'P53611', 'P07814', 'P30518', 'P09417', 'Q92731', 'Q7Z4W1', 'P35869', 'Q9HCP0', 'Q99418', 'P17213', 'O14786', 'P61313', 'Q9UHY7', 'O76074', 'P05089', 'P08620', 'P09238', 'P32297', 'P08263', 'P30405', 'P36871', 'P23975', 'P62258', 'P49368', 'P19883', 'P60880', 'Q16654', 'O76083', 'Q9NRF9', 'Q9NPA2', 'P24530', 'P08172', 'P15056', 'P17931', 'P07998', 'P25100', 'P02760', 'Q07343', 'Q99497', 'Q04828', 'Q04771', 'P01215', 'Q07699', 'Q16651', 'P12268', 'O60678', 'Q8WTS6', 'O14874', 'Q15800', 'Q05397', 'P47871', 'P43403', 'Q9UHA3', 'P25788', 'Q92781', 'P01857', 'P52209', 'P20645', 'P12259', 'O00244', 'P63092', 'Q9UL54', 'P27169', 'P13688', 'Q9UBN7', 'O43424', 'P30559', 'P59768', 'P45877', 'P39900', 'P10606', 'P30531', 'P07437', 'Q9BYC2', 'P21817', 'Q9UKV0', 'P60900', 'Q8IU85', 'P68104', 'P04275', 'P51151', 'Q9Y5R2', 'P02792', 'P23469', 'O75936', 'P35247', 'P20248', 'Q5JAM2', 'P06213', 'P23458', 'P25024', 'P35462', 'Q00796', 'P07948', 'Q92769', 'P16519', 'P17174', 'P09958', 'O95169', 'Q8NER1', 'P32019', 'P12319', 'P31949', 'Q9Y678', 'O00222', 'Q92876', 'O14727', 'P05113', 'P23368', 'P00558', 'P15538', 'P61925', 'P55017', 'P61587', 'P17812', 'P05981', 'P06746', 'O14965', 'P16455', 'P35520', 'P29475', 'P21695', 'P52333', 'P27824', 'P00797', 'Q99735', 'P17600', 'Q16671', 'P12081', 'Q9ULK0', 'P63151', 'P31785', 'Q9UMX2', 'P06850', 'Q07864', 'Q15075', 'P08253', 'Q07075', 'P09871', 'Q71U36', 'Q13126', 'P63096', 'P06729', 'Q99250', 'Q9UIC8', 'P60981', 'P49773', 'P42081', 'P50281', 'Q15046', 'P04083', 'Q16790', 'Q15661', 'Q16718', 'P29803', 'P43681', 'Q16539', 'Q01082', 'Q9BWD1', 'Q04759', 'P08700', 'Q9BY49', 'P28072', 'P21964', 'Q96LZ3', 'Q15596', 'P60953', 'P51511', 'Q16762', 'P05556', 'P48735', 'P48552', 'P32320', 'P42263', 'Q14831', 'O75390', 'P0C0L4', 'P55157', 'P10276', 'P49327', 'P49916', 'Q9UBX3', 'Q9UN19', 'P04004', 'O15382', 'P29317', 'P32119', 'O43681', 'P11766', 'O60551', 'Q9NR82', 'P09668', 'P06126', 'Q96SL4', 'P21796', 'P62917', 'P05023', 'P05107', 'P49915', 'P00374', 'P14174', 'Q9UG56', 'P09917', 'Q9HAN9', 'Q13639', 'P32927', 'Q96GA7', 'Q03013', 'P19404', 'Q9Y478', 'P01574', 'P01303', 'Q9Y277', 'P31323', 'P08243', 'Q13255', 'P62136', 'P13051', 'P10144', 'O60264', 'P20701', 'Q9NZV8', 'P29972', 'Q13547', 'P29218', 'P05362', 'P23284', 'P30939', 'P62312', 'Q16881', 'Q13557', 'P48736', 'P16473', 'P07339', 'P56696', 'P05106', 'P01375', 'P01024', 'P35914', 'P35228', 'P05154', 'P26951', 'P35916', 'Q32P28', 'Q13698', 'P45984', 'P05187', 'P04150', 'P51843', 'P22694', 'P00326', 'O15379', 'P53609', 'P30536', 'P02458', 'P68371', 'P62829', 'P06401', 'Q00535', 'P09012', 'P48067', 'P19320', 'P07101', 'Q08828', 'P20941', 'Q13554', 'P43119', 'Q13370', 'O15269', 'O14646', 'P62942', 'P25705', 'O60603', 'O43772', 'P04424', 'P07108', 'P07947', 'Q9Y6M4', 'Q9UQM7', 'P12429', 'P54819', 'P09960', 'O43252', 'Q9Y689', 'P07510', 'P50440', 'P13569', 'P05121', 'P19367', 'P30542', 'O95477', 'P26640', 'P21554', 'P49888', 'P30793', 'Q05586', 'O14832', 'P23921', 'P23946', 'P08107', 'P08173', 'P41145', 'Q53H96', 'P03886', 'P00533', 'P01920', 'Q15120', 'Q99436', 'O43525', 'P00488', 'P11233', 'P49747', 'P17948', 'O75600', 'O00141', 'P11387', 'P24941', 'Q96C86', 'P56524', 'P52732', 'P24386', 'Q9NNX6', 'P78417', 'P30044', 'Q9UNN8', 'Q13427', 'Q9Y275', 'P12821', 'P17540', 'P62750', 'P26447', 'P05164', 'P21397', 'Q96LB9', 'O43526', 'P48029', 'P11137', 'Q14541', 'Q9NPH2', 'Q01668', 'Q02750', 'Q02763', 'Q02643', 'P11511', 'Q14749', 'P23434', 'P31150', 'P23219', 'Q8NEB9', 'Q96GD3', 'P10153', 'P10515', 'P10644', 'P03956', 'P07307', 'O43612', 'P07360', 'P02452', 'Q8IWT1', 'Q92993', 'P34896', 'P49913', 'P31327', 'P16066', 'P04183', 'P05787', 'P45452', 'Q9Y4W6', 'P17707', 'P11802', 'P62826', 'P19793', 'P25103', 'P49821', 'P02776', 'P32239', 'Q16853', 'P49366', 'Q9NR33', 'P25021', 'P07741', 'P14210', 'P43405', 'P41595', 'P40926', 'Q96QT4', 'P30532', 'O76054', 'Q99933', 'P25786', 'P14543', 'P11177', 'P62837', 'Q92793', 'P23297', 'Q5JTZ9', 'P28221', 'P22303', 'P14778', 'O14842', 'P53350', 'P07766', 'P13073', 'P05230', 'P48048', 'P18825', 'P18545', 'Q9UKU7', 'P17252', 'Q14654', 'P25311', 'P61586', 'Q6IA69', 'Q969G6', 'P02585', 'P51149', 'P07355', 'Q07954', 'P30153', 'P22570', 'O14594', 'P47989', 'Q06609', 'P47929', 'P43115', 'O60882', 'P11229', 'P16870', 'P00568', 'Q8NFW8', 'Q9Y296', 'P98170', 'P08069', 'O43708', 'O00206', 'P23528', 'P18085', 'Q00169', 'P01589', 'P27708', 'P30043', 'P46926', 'P04118', 'Q07869', 'Q9UP95', 'P49759', 'O00763', 'P35348', 'P46976', 'P80511', 'P27797', 'P35568', 'P04629', 'Q99928', 'P00709', 'P63208', 'P49286', 'P20339', 'P36897', 'Q14032', 'Q96CD2', 'O14936', 'Q14833', 'Q9NP99', 'P15289', 'P07333', 'P07686', 'Q9BY66', 'P23945', 'P02461', 'P08709', 'P15531', 'P00492', 'P30041', 'Q96KS0', 'Q9P2J5', 'Q14416', 'Q93088', 'Q9NY65', 'P17538', 'P18507', 'P31645', 'P36959', 'P47712', 'Q9NYX4', 'P10745', 'P05177', 'P54317', 'P02745', 'P02747', 'Q9ULA0', 'P14784', 'P08174', 'P53597', 'Q9ULZ9', 'P28222', 'P22033', 'P30926', 'P07237', 'P10646', 'P00387', 'Q99798', 'P78352', 'O00329', 'P08758', 'P43003', 'Q03405', 'P08729', 'P42658', 'P05231', 'Q16099', 'O75116', 'O60840', 'P00519', 'P25116', 'P45880', 'P08581', 'P35080', 'P25815', 'O75899', 'P30988', 'P36021', 'O60701', 'P35222', 'O43776', 'O14920', 'P07858', 'Q00536', 'P13716', 'O60894', 'P17342', 'O75676', 'Q14626', 'Q99584', 'P32418', 'P20231', 'Q9NXE4', 'P35243', 'P02743', 'P28074', 'P11717', 'Q9UBT6', 'P23763', 'P09104', 'O14582', 'P01130', 'P15813', 'P26358', 'P51649', 'P08631', 'P26196', 'Q96GD0', 'P21462', 'O00408', 'P22607', 'Q14832', 'P48551', 'P08519', 'P62330', 'P36873', 'P20594', 'P28331', 'P49419', 'Q15788', 'P48443', 'P02649', 'P07602', 'O14964', 'P40394', 'O95907', 'P80188', 'P04406', 'P35749', 'P36222', 'Q16873', 'P00352', 'P80192', 'O43617', 'P78362', 'P34903', 'Q08426', 'P12236', 'O95573', 'Q9BZ11', 'P16234', 'P14061', 'P06396', 'Q13464', 'P08123', 'O15144', 'P47870', 'P14410', 'P16220', 'P10109', 'Q8IVH4', 'P35030', 'P21731', 'Q9NPG2', 'P04350', 'Q9BZX2', 'P21266', 'P14867', 'P50613', 'Q13085', 'P30291', 'Q15078', 'P06276', 'Q9P0Z9', 'O15427', 'Q15796', 'P30084', 'P35398', 'P15151', 'P09455', 'P16233', 'P19440', 'P07204', 'P55055', 'P11310', 'Q03393', 'P01031', 'P48147', 'P05062', 'P41180', 'P28472', 'P12883', 'Q9UBK8', 'P43116', 'O14717', 'Q08345', 'Q00653', 'P02749', 'Q16836', 'Q03403', 'P07737', 'Q9UK23', 'P21802', 'Q03181', 'Q99661', 'P37231', 'P31151', 'P61457', 'P14550', 'P15502', 'P52788', 'P22061', 'O94766', 'Q16222', 'P16471', 'P06730', 'P22314', 'P14555', 'O60664', 'P09669', 'P40429', 'P17936', 'P62508', 'P50406', 'P78368', 'P62993', 'O95182', 'P18505', 'Q9BY41', 'Q9UJ70', 'Q08499', 'P78348', 'P69905', 'P01009', 'Q96C24', 'P28066', 'P61769', 'P49720', 'P50607', 'P49137', 'P48167', 'P16581', 'P29508', 'P09693', 'P10114', 'P32745', 'P01189', 'P32238', 'P09467', 'P30411', 'P49146', 'P23258', 'P07202', 'P04156', 'P13929', 'Q9Y233', 'Q99720', 'O43293', 'Q01718', 'P14927', 'Q6P996', 'P05162', 'Q96C36', 'P35557', 'O60568', 'P16444', 'Q01959', 'P63027', 'Q15758', 'P11217', 'Q9Y6I3', 'P55769', 'Q13936', 'P24347', 'Q9NSA0', 'Q13393', 'P43004', 'P08913', 'P22888', 'Q08209', 'P04637', 'Q9NRX3', 'P07359', 'O15511', 'Q15274', 'P80075', 'Q92831', 'Q07817', 'Q13555', 'P07585', 'P53985', 'Q5VZ30', 'P20963', 'P09172', 'P49841', 'P17302', 'P11172', 'O00305', 'P49257', 'P53778', 'P19801', 'P11362', 'P20333', 'P10696', 'P15144', 'P62873', 'Q9UPY5', 'Q9UK17', 'Q9HCD5', 'P14625', 'Q13093', 'Q12866', 'P35968', 'P06239', 'O14764', 'O15111', 'P05165', 'P20810', 'Q9P2R7', 'O14788', 'P28702', 'P11498', 'P36542', 'Q15382', 'P10620', 'Q9BVJ7', 'O43451', 'P30556', 'Q15119', 'P49789', 'P11309', 'P00747', 'P04075', 'P11413', 'P52961', 'P03951', 'P51955', 'P08887', 'O14757', 'P78536', 'O60494', 'P08559', 'P05452', 'O95865', 'P61158', 'P36896', 'P33176', 'Q96RI1', 'P13501', 'P28335', 'P07098', 'Q16566', 'P02671', 'P02818', 'P49585', 'P06858', 'P39023', 'P07477', 'P27815', 'P09237', 'O43447', 'P62491', 'P00451', 'P14920', 'P02778', 'P42684', 'P48058', 'Q8N9I0', 'O60895', 'P06737', 'Q9UNA0', 'Q9UL51', 'Q10472', 'P16860', 'Q16775', 'P0C0L5', 'P00167', 'P20138'}
Missing proteins: {'DB00861', 'DB01320', 'DB00367', 'DB00602', 'DB00972', 'DB01059', 'DB00485', 'DB08860', 'DB00547', 'DB00375', 'DB01400', 'DB00650', 'DB00186', 'DB01050', 'DB01117', 'DB00535', 'DB04839', 'DB00481', 'DB00395', 'DB00486', 'DB00881', 'DB00507', 'DB00994', 'DB00359', 'DB01097', 'DB00198', 'DB01041', 'DB00213', 'DB00455', 'DB00897', 'DB00773', 'DB01095', 'DB00590', 'DB00619', 'DB00404', 'DB01362', 'DB00537', 'DB00905', 'DB00188', 'DB01586', 'DB00398', 'DB00444', 'DB00899', 'DB00500', 'DB00750', 'DB00697', 'DB00731', 'DB00672', 'DB01005', 'DB01012', 'DB01261', 'DB00302', 'DB01118', 'DB00184', 'DB00857', 'DB01260', 'DB00227', 'DB00851', 'DB00490', 'DB00678', 'DB00713', 'DB00533', 'DB00674', 'DB00983', 'DB00701', 'DB01032', 'DB00543', 'DB00967', 'DB01364', 'DB00461', 'DB00457', 'DB00826', 'DB01083', 'DB00591', 'DB01114', 'DB00263', 'DB00996', 'DB00680', 'DB00152', 'DB00927', 'DB00439', 'DB00542', 'DB01340', 'DB01086', 'DB00685', 'DB01156', 'DB00975', 'DB00204', 'DB00734', 'DB00963', 'DB00671', 'DB00898', 'DB00323', 'DB01072', 'DB01030', 'DB00280', 'DB00211', 'DB01206', 'DB06800', 'DB00595', 'DB00408', 'DB01106', 'DB00421', 'DB00343', 'DB01048', 'DB00796', 'DB00612', 'DB01151', 'DB01211', 'DB00242', 'DB00980', 'DB00317', 'DB01076', 'DB01427', 'DB00197', 'DB01088', 'DB00968', 'DB00415', 'DB01120', 'DB00696', 'DB06287', 'DB00835', 'DB00912', 'DB00381', 'DB06209', 'DB01026', 'DB00921', 'DB00973', 'DB01127', 'DB01198', 'DB01223', 'DB00806', 'DB00668', 'DB00649', 'DB00199', 'DB00615', 'DB01112', 'DB00261', 'DB01183', 'DB00276', 'DB00759', 'DB00254', 'DB00191', 'DB00185', 'DB01019', 'DB00520', 'DB01233', 'DB01241', 'DB00811', 'DB00693', 'DB01119', 'DB00182', 'DB00623', 'DB06695', 'DB00418', 'DB00446', 'DB01136', 'DB00951', 'DB00605', 'DB00363', 'DB01319', 'DB01230', 'DB01280', 'DB00548', 'DB00819', 'DB00287', 'DB04845', 'DB00593', 'DB00907', 'DB00341', 'DB00571', 'DB00582', 'DB00950', 'DB00368', 'DB00988', 'DB00758', 'DB00299', 'DB00749', 'DB00673', 'DB00530', 'DB00413', 'DB00339', 'DB00933', 'DB01205', 'DB01200', 'DB00289', 'DB00753', 'DB00494', 'DB00284', 'DB00491', 'DB01169', 'DB00625', 'DB00401', 'DB01627', 'DB00651', 'DB01329', 'DB00814', 'DB00646', 'DB00231', 'DB00876', 'DB00961', 'DB01168', 'DB01203', 'DB00844', 'DB00776', 'DB00288', 'DB00706', 'DB01162', 'DB01098', 'DB01331', 'DB01035', 'DB00594', 'DB00442', 'DB00283', 'DB01618', 'DB00268', 'DB00938', 'DB04861', 'DB00573', 'DB00656', 'DB00915', 'DB01190', 'DB00337', 'DB00264', 'DB01337', 'DB06212', 'DB00330', 'DB01330', 'DB00203', 'DB01100', 'DB00959', 'DB00620', 'DB01130', 'DB00549', 'DB00953', 'DB01143', 'DB00937', 'DB01084', 'DB00908', 'DB00296', 'DB00641', 'DB00909', 'DB00331', 'DB01393', 'DB00879', 'DB01254', 'DB00863', 'DB00393', 'DB00313', 'DB00896', 'DB00344', 'DB01220', 'DB01167', 'DB01263', 'DB00986', 'DB00555', 'DB00631', 'DB00952', 'DB00297', 'DB01378', 'DB01577', 'DB00550', 'DB00708', 'DB00989', 'DB00497', 'DB00224', 'DB00683', 'DB00482', 'DB00477', 'DB00312', 'DB00478', 'DB00338', 'DB01165', 'DB01011', 'DB06710', 'DB00293', 'DB00515', 'DB01217', 'DB00561', 'DB01356', 'DB01018', 'DB00377', 'DB00400', 'DB00799', 'DB00479', 'DB01656', 'DB00640', 'DB00362', 'DB00207', 'DB00911', 'DB00839', 'DB00869', 'DB01610', 'DB00271', 'DB00757', 'DB00700', 'DB00277', 'DB00382', 'DB00318', 'DB01250', 'DB00621', 'DB01115', 'DB01148', 'DB01132', 'DB00964', 'DB00874', 'DB01242', 'DB00370', 'DB00195', 'DB00349', 'DB00654', 'DB01273', 'DB00818', 'DB01248', 'DB00246', 'DB00836', 'DB01158', 'DB00356', 'DB00519', 'DB01029', 'DB01182', 'DB00966', 'DB00282', 'DB01212', 'DB00730', 'DB00316', 'DB00531', 'DB00252', 'DB00235', 'DB00431', 'DB00782', 'DB00958', 'DB00412', 'DB00586', 'DB01161', 'DB00218', 'DB00860', 'DB00581', 'DB01258', 'DB00472', 'DB00630', 'DB00361', 'DB00396', 'DB01221', 'DB00841', 'DB00742', 'DB00870', 'DB06201', 'DB01367', 'DB00564', 'DB00200', 'DB06228', 'DB01023', 'DB00601', 'DB01044', 'DB00920', 'DB01064', 'DB00422', 'DB01149', 'DB00703', 'DB00388', 'DB05246', 'DB00364', 'DB01174', 'DB00828', 'DB00177', 'DB01609', 'DB00492', 'DB01060', 'DB01006', 'DB01236', 'DB00484', 'DB00480', 'DB00783', 'DB00903', 'DB06335', 'DB00176', 'DB00705', 'DB00842', 'DB00458', 'DB00993', 'DB00611', 'DB00794', 'DB00979', 'DB01105', 'DB00756', 'DB00698', 'DB00247', 'DB00449', 'DB04572', 'DB00450', 'DB00772', 'DB00704', 'DB00495', 'DB00732', 'DB01210', 'DB04896', 'DB01214', 'DB00423', 'DB00762', 'DB06204', 'DB01196', 'DB01140', 'DB01142', 'DB01218', 'DB00369', 'DB00201', 'DB00187', 'DB00371', 'DB00222', 'DB01215', 'DB00417', 'DB00900', 'DB01188', 'DB00441', 'DB00618', 'DB00889', 'DB01039', 'DB00281', 'DB00557', 'DB00315', 'DB01623', 'DB01390', 'DB06216', 'DB01058', 'DB01113', 'DB00238', 'DB01238', 'DB01024', 'DB00598', 'DB00193', 'DB00795', 'DB00572', 'DB00384', 'DB00162', 'DB00502', 'DB01224', 'DB00585', 'DB00827', 'DB00872', 'DB00324', 'DB00622', 'DB00303', 'DB00820', 'DB00310', 'DB00524', 'DB00285', 'DB01067', 'DB00779', 'DB00196', 'DB00273', 'DB00208', 'DB00501', 'DB01327', 'DB00321', 'DB00503', 'DB00426', 'DB00883', 'DB01069', 'DB00541', 'DB00981', 'DB00675', 'DB01013', 'DB01612', 'DB00609', 'DB00718', 'DB00850', 'DB00563', 'DB00327', 'DB01406', 'DB00539', 'DB00180', 'DB00580', 'DB00440', 'DB01128', 'DB00813', 'DB00918', 'DB01073', 'DB00175', 'DB01558', 'DB00308', 'DB00319', 'DB00910', 'DB00575', 'DB00205', 'DB01413', 'DB08826', 'DB00532', 'DB00529', 'DB00962', 'DB00050', 'DB05245', 'DB01129', 'DB00695', 'DB00924', 'DB00843', 'DB01042', 'DB01195', 'DB00376', 'DB01008', 'DB00438', 'DB00745', 'DB00949', 'DB00489', 'DB00978', 'DB00991', 'DB00714', 'DB01186', 'DB00567', 'DB00228', 'DB01085', 'DB00243', 'DB00694', 'DB00712', 'DB01249', 'DB06402', 'DB01104', 'DB01410', 'DB00374', 'DB00733', 'DB00220', 'DB00780', 'DB00328', 'DB00475', 'DB00465', 'DB00682', 'DB00496', 'DB00652', 'DB00679', 'DB00887', 'DB00351', 'DB01068', 'DB01204', 'DB00763', 'DB00724', 'DB00355', 'DB00916', 'DB00350', 'DB00460', 'DB00665', 'DB00690', 'DB01101', 'DB00635', 'DB00932', 'DB01264', 'DB01181', 'DB00560', 'DB01247', 'DB01324', 'DB01226', 'DB00906', 'DB00904', 'DB00425', 'DB00419', 'DB01062', 'DB01416', 'DB00178', 'DB00804', 'DB01595', 'DB00659', 'DB01150', 'DB00768', 'DB01141', 'DB00554', 'DB00727', 'DB01079', 'DB00335', 'DB00257', 'DB00751', 'DB01110', 'DB00357', 'DB00740', 'DB00833', 'DB00657', 'DB00681', 'DB01189', 'DB01232', 'DB00210', 'DB01394', 'DB00181', 'DB01219', 'DB00691', 'DB00721', 'DB01185', 'DB00788', 'DB00437', 'DB00878', 'DB01591', 'DB00864', 'DB00291', 'DB00278', 'DB01611', 'DB00559', 'DB01194', 'DB00433', 'DB00216', 'DB01082', 'DB00434', 'DB01046', 'DB00728', 'DB00831', 'DB00834', 'DB00871', 'DB01009', 'DB01222', 'DB01017', 'DB01275', 'DB01159', 'DB00829', 'DB00320', 'DB06811', 'DB00346', 'DB00333', 'DB00929', 'DB01184', 'DB00648', 'DB01216', 'DB00454', 'DB01268', 'DB06698', 'DB00642', 'DB01037', 'DB00853', 'DB00810', 'DB01004', 'DB00540', 'DB00476', 'DB00925', 'DB06274', 'DB00684', 'DB00230', 'DB01177', 'DB00248', 'DB00976', 'DB01173', 'DB00292', 'DB00526', 'DB00802', 'DB01087', 'DB01047', 'DB01332', 'DB00715', 'DB06711', 'DB01036', 'DB01291', 'DB00845', 'DB00373', 'DB01080', 'DB00809', 'DB01421', 'DB00661', 'DB00692', 'DB00411', 'DB00784', 'DB00521', 'DB00250', 'DB00538', 'DB00633', 'DB00603', 'DB00608', 'DB00669', 'DB00800', 'DB00998', 'DB01409', 'DB00358', 'DB00790', 'DB00862', 'DB00687', 'DB00518', 'DB00202', 'DB00295', 'DB00624', 'DB00474', 'DB01234', 'DB00471', 'DB00215', 'DB00206', 'DB01014', 'DB00919', 'DB01201', 'DB01229', 'DB01157', 'DB04844', 'DB01126', 'DB00307', 'DB04930', 'DB00459', 'DB00822', 'DB00493', 'DB01075', 'DB00499', 'DB00997', 'DB00868', 'DB00808', 'DB00990', 'DB00334', 'DB00960', 'DB01193', 'DB00578', 'DB00448', 'DB00754', 'DB00390', 'DB00709', 'DB00969', 'DB01043', 'DB00999', 'DB00379'}
In [203]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import SGDClassifier
from tqdm import tqdm
# Train logistic regression model
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
clf = SGDClassifier(loss='log_loss', max_iter=1, tol=None, warm_start=True, random_state=42)
# Custom training loop with progress bar
n_iterations = 1000 # Number of iterations for the progress bar
chunk_size = len(X_train_scaled) // n_iterations
for _ in tqdm(range(n_iterations), desc="Training progress"):
indices = np.random.choice(len(X_train_scaled), chunk_size, replace=False)
clf.partial_fit(X_train_scaled[indices], y_train[indices], classes=np.unique(y_train))
Training progress: 100%|███████████████████████████████████████████████████████████| 1000/1000 [00:14<00:00, 70.33it/s]
Training completed
In [204]:
# Save the model
import joblib
model_path = 'DTINet/data/predictionModel_SGDC.pkl'
joblib.dump(clf, model_path)
print(f"Model saved to {model_path}")
Model saved to DTINet/data/predictionModel_SGDC.pkl
In [197]:
# Load the model
import joblib
model_path = 'DTINet/data/predictionModel_SGDC.pkl'
clf = joblib.load(model_path)
print("Model loaded successfully")
Model loaded successfully
In [205]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, recall_score, f1_score, roc_curve, precision_recall_curve
# Predict potential links
y_pred_proba = clf.predict_proba(X_test_scaled)[:, 1]
y_pred = clf.predict(X_test_scaled)
# Evaluation metrics
roc_auc = roc_auc_score(y_test, y_pred_proba)
average_precision = average_precision_score(y_test, y_pred_proba)
accuracy = accuracy_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
# Print the evaluation metrics
print(f"ROC-AUC: {roc_auc:.4f}")
print(f"Average Precision: {average_precision:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
# Plot ROC curve
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
# Plot Precision-Recall curve
precision, recall, _ = precision_recall_curve(y_test, y_pred_proba)
plt.subplot(1, 2, 2)
plt.plot(recall, precision, color='blue', lw=2, label=f'Precision-Recall curve (area = {average_precision:.2f})')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.legend(loc="lower left")
plt.suptitle('SGDC Model Evaluation Metrics', fontsize=16)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.savefig('DTINet/plots/_model_evaluation.png', format='png', dpi=300)
plt.show()
ROC-AUC: 0.6036 Average Precision: 0.0325 Accuracy: 0.9962 Recall: 0.0000 F1 Score: 0.0000
In [206]:
# Predict Potential Links
y_pred = clf.predict(X_test)
predicted_links = [(pairs_test[i][0], pairs_test[i][1]) for i in range(len(y_pred)) if y_pred[i] == 1 and y_test[i] == 0]
print("Predicted new links:")
for link in predicted_links:
print(link)
Predicted new links:
In [207]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
# Create a new graph for predicted links
G_predicted = nx.Graph()
# Add predicted nodes and edges
for u, v in predicted_links:
G_predicted.add_node(u, bipartite=0)
G_predicted.add_node(v, bipartite=1)
G_predicted.add_edge(u, v)
# All edges in data
all_known_edges = [(ls_drug[i], ls_protein[j]) for i, j in zip(*np.where(np_mat_drug_protein == 1))]
# Add known edges between the predicted nodes
known_edges = [(u, v) for u, v in predicted_links if np_mat_drug_protein[ls_drug.index(u)][ls_protein.index(v)] == 1]
# Identify nodes involved in predicted edges
predicted_nodes = set([u for u, v in predicted_links] + [v for u, v in predicted_links])
# Identify all known edges connected to predicted nodes but not in predicted edges
additional_edges = [
(u, v) for u, v in all_known_edges
if (u in predicted_nodes or v in predicted_nodes) and (u, v) not in predicted_edges and (u, v) not in known_edges
]
# Ensure all nodes in additional_edges are added to the graph
for u, v in additional_edges:
if u not in G_predicted:
G_predicted.add_node(u, bipartite=0 if u in ls_drug else 1)
if v not in G_predicted:
G_predicted.add_node(v, bipartite=0 if v in ls_drug else 1)
# Draw the predicted links graph
plt.figure(figsize=(20, 10))
pos = nx.spring_layout(G_predicted, k=0.1, iterations=5) # Adjust k for node spacing
# Draw known edges in green
nx.draw_networkx_edges(G_predicted, pos, edgelist=known_edges, edge_color='green', width=2)
# Draw predicted links in red
predicted_edges = [(u, v) for u, v in predicted_links if np_mat_drug_protein[ls_drug.index(u)][ls_protein.index(v)] == 0]
nx.draw_networkx_edges(G_predicted, pos, edgelist=predicted_edges, edge_color='red', width=2)
# Draw additional known edges in black
nx.draw_networkx_edges(G_predicted, pos, edgelist=additional_edges, edge_color='gray', width=0.2)
# Draw nodes with different colors for drugs and proteins
drug_nodes = [node for node in G_predicted.nodes if node in ls_drug]
protein_nodes = [node for node in G_predicted.nodes if node in ls_protein]
nx.draw_networkx_nodes(G_predicted, pos, nodelist=drug_nodes, node_color='lightgreen', node_size=700, label='Drugs')
nx.draw_networkx_nodes(G_predicted, pos, nodelist=protein_nodes, node_color='pink', node_size=700, label='Proteins')
# Draw labels
nx.draw_networkx_labels(G_predicted, pos, font_size=6, font_family='sans-serif')
plt.title("Predicted Drug-Protein Interaction Network")
plt.legend(markerscale=0.35)
plt.savefig('DTINet/plots/_predicted_drug_protein_interaction_network.png', format='png', dpi=300)
plt.show()
No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
Link prediction with GNC¶
In [266]:
# Visualize a sample of network
import random
import networkx as nx
import numpy as np
# Set the size of the sample set
setSize = 10
similarityTreshold = 0.7
# Select drugs for the sample set
# sample_drugs = random.sample(ls_drug, setSize)
sample_drugs = ls_drug[:setSize]
# sample_drugs = ls_drug[-setSize:]
# Create a new graph for the sample set
G_sample = nx.Graph()
# Add nodes for drugs, proteins, diseases, and side effects
for drug in sample_drugs:
G_sample.add_node(drug, bipartite=0, color='blue') # Drug nodes in lightgreen
for protein in ls_protein:
if np_mat_drug_protein[ls_drug.index(drug), ls_protein.index(protein)] == 1:
G_sample.add_node(protein, bipartite=1, color='red') # Protein nodes in pink
G_sample.add_edge(drug, protein, color='red') # Drug-protein edges in green
for disease in ls_disease:
if np_mat_drug_disease[ls_drug.index(drug), ls_disease.index(disease)] == 1:
G_sample.add_node(disease, bipartite=2, color='orange') # Disease nodes in orange
G_sample.add_edge(drug, disease, color='orange') # Drug-disease edges in green
for side_effect in ls_sideEffect:
if np_mat_drug_sideEffect[ls_drug.index(drug), ls_sideEffect.index(side_effect)] == 1:
G_sample.add_node(side_effect, bipartite=3, color='gray') # Side effect nodes in gray
G_sample.add_edge(drug, side_effect, color='gray') # Drug-side effect edges in green
# Add similarity edges between drugs
for i, drug1 in enumerate(sample_drugs):
for j, drug2 in enumerate(sample_drugs):
if i != j:
similarity = np_mat_simDrugs[ls_drug.index(drug1), ls_drug.index(drug2)]
if similarity > similarityTreshold:
G_sample.add_edge(drug1, drug2, color='blue', weight=similarity*2)
# Plot the graph
plt.figure(figsize=(20, 10))
pos = nx.spring_layout(G_sample, k=1, iterations=2)
edge_colors = [G_sample[u][v]['color'] for u, v in G_sample.edges()]
node_colors = [G_sample.nodes[n].get('color', 'yellow') for n in G_sample.nodes()]
nx.draw(G_sample, pos, with_labels=False, node_color=node_colors, edge_color=edge_colors, width=0.33, node_size=50)
plt.title("Graph with Sample Drug Set")
plt.text(0.95, 0.95, 'Drugs', color='blue', fontsize=10, transform=plt.gca().transAxes)
plt.text(0.95, 0.90, 'Proteins', color='red', fontsize=10, transform=plt.gca().transAxes)
plt.text(0.95, 0.85, 'Diseases', color='orange', fontsize=10, transform=plt.gca().transAxes)
plt.text(0.95, 0.80, 'Side Effects', color='gray', fontsize=10, transform=plt.gca().transAxes)
plt.savefig('DTINet/plots/_sample_drugset_interaction_network.png', format='png', dpi=300)
plt.show()
Unified graph¶
In [267]:
import networkx as nx
import numpy as np
# Create a unified graph
G = nx.Graph()
# Add nodes for drugs, proteins, diseases, and side effects
for drug in ls_drug:
G.add_node(drug, node_type='drug')
for protein in ls_protein:
G.add_node(protein, node_type='protein')
for disease in ls_disease:
G.add_node(disease, node_type='disease')
for side_effect in ls_sideEffect:
G.add_node(side_effect, node_type='side_effect')
# Add edges for drug-protein interactions
for i, drug in enumerate(ls_drug):
for j, protein in enumerate(ls_protein):
if np_mat_drug_protein[i, j] == 1:
G.add_edge(drug, protein, edge_type='drug_protein')
# Add edges for drug-disease interactions
for i, drug in enumerate(ls_drug):
for j, disease in enumerate(ls_disease):
if np_mat_drug_disease[i, j] == 1:
G.add_edge(drug, disease, edge_type='drug_disease')
# Add edges for drug-side effect interactions
for i, drug in enumerate(ls_drug):
for j, side_effect in enumerate(ls_sideEffect):
if np_mat_drug_sideEffect[i, j] == 1:
G.add_edge(drug, side_effect, edge_type='drug_side_effect')
# Add edges for drug-drug similarities
for i, drug1 in enumerate(ls_drug):
for j, drug2 in enumerate(ls_drug):
if i != j and np_mat_simDrugs[i, j] > similarityTreshold:
G.add_edge(drug1, drug2, edge_type='drug_drug_sim', weight=np_mat_simDrugs[i, j])
# Add edges for protein-protein similarities
for i, protein1 in enumerate(ls_protein):
for j, protein2 in enumerate(ls_protein):
if i != j and np_mat_simProteins[i, j] > similarityTreshold:
G.add_edge(protein1, protein2, edge_type='protein_protein_sim', weight=np_mat_simProteins[i, j])
# Print the number of nodes and edges in the unified graph
print(f'Unified Graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges')
Unified Graph: 11317 nodes, 1384665 edges
Feature extraction with GCN¶
In [270]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from tqdm import tqdm
# Create the feature matrix (we'll use one-hot encoding for simplicity)
features = np.eye(G.number_of_nodes())
# Create the edge index for the graph
edge_index = []
for edge in G.edges:
node1, node2 = edge
edge_index.append([list(G.nodes).index(node1), list(G.nodes).index(node2)])
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
# Convert the feature matrix to a tensor
x = torch.tensor(features, dtype=torch.float)
# Create a PyTorch Geometric data object
data = Data(x=x, edge_index=edge_index)
class GCN(torch.nn.Module):
def __init__(self):
super(GCN, self).__init__()
self.conv1 = GCNConv(data.num_features, 16)
self.conv2 = GCNConv(16, 2) # Output 2 features for each node
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
return x
# Initialize the model, optimizer, and loss function
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
# Training the GCN
def train():
model.train()
optimizer.zero_grad()
out = model(data)
loss = criterion(out, torch.tensor([G.nodes[node]['node_type'] == 'drug' for node in G.nodes], dtype=torch.long))
loss.backward()
optimizer.step()
return loss.item()
# Training loop with progress bar
epochs = 200
for epoch in tqdm(range(epochs), desc="Training GCN"):
loss = train()
if epoch % 10 == 0:
print(f'Epoch {epoch}, Loss: {loss:.4f}')
# Extract node features
model.eval()
with torch.no_grad():
node_features = model(data).numpy()
print("Node features extracted using GCN.")
Training GCN: 0%|▎ | 1/200 [00:03<12:13, 3.69s/it]
Epoch 0, Loss: 0.6980
Training GCN: 6%|███▋ | 11/200 [00:07<01:27, 2.17it/s]
Epoch 10, Loss: 0.2608
Training GCN: 10%|███████ | 21/200 [00:11<01:13, 2.45it/s]
Epoch 20, Loss: 0.1717
Training GCN: 16%|██████████▍ | 31/200 [00:15<01:06, 2.56it/s]
Epoch 30, Loss: 0.1642
Training GCN: 20%|█████████████▋ | 41/200 [00:19<01:03, 2.51it/s]
Epoch 40, Loss: 0.1474
Training GCN: 26%|█████████████████ | 51/200 [00:23<00:59, 2.51it/s]
Epoch 50, Loss: 0.1334
Training GCN: 30%|████████████████████▍ | 61/200 [00:27<00:55, 2.51it/s]
Epoch 60, Loss: 0.1255
Training GCN: 36%|███████████████████████▊ | 71/200 [00:31<00:51, 2.52it/s]
Epoch 70, Loss: 0.1188
Training GCN: 40%|███████████████████████████▏ | 81/200 [00:35<00:47, 2.49it/s]
Epoch 80, Loss: 0.1126
Training GCN: 46%|██████████████████████████████▍ | 91/200 [00:40<00:51, 2.14it/s]
Epoch 90, Loss: 0.1072
Training GCN: 50%|█████████████████████████████████▎ | 101/200 [00:44<00:39, 2.50it/s]
Epoch 100, Loss: 0.1024
Training GCN: 56%|████████████████████████████████████▋ | 111/200 [00:48<00:36, 2.45it/s]
Epoch 110, Loss: 0.0980
Training GCN: 60%|███████████████████████████████████████▉ | 121/200 [00:52<00:30, 2.55it/s]
Epoch 120, Loss: 0.0939
Training GCN: 66%|███████████████████████████████████████████▏ | 131/200 [00:56<00:27, 2.55it/s]
Epoch 130, Loss: 0.0902
Training GCN: 70%|██████████████████████████████████████████████▌ | 141/200 [01:00<00:23, 2.53it/s]
Epoch 140, Loss: 0.0866
Training GCN: 76%|█████████████████████████████████████████████████▊ | 151/200 [01:04<00:19, 2.55it/s]
Epoch 150, Loss: 0.0833
Training GCN: 80%|█████████████████████████████████████████████████████▏ | 161/200 [01:08<00:15, 2.54it/s]
Epoch 160, Loss: 0.0801
Training GCN: 86%|████████████████████████████████████████████████████████▍ | 171/200 [01:12<00:11, 2.48it/s]
Epoch 170, Loss: 0.0771
Training GCN: 90%|███████████████████████████████████████████████████████████▋ | 181/200 [01:15<00:07, 2.55it/s]
Epoch 180, Loss: 0.0743
Training GCN: 96%|███████████████████████████████████████████████████████████████ | 191/200 [01:19<00:03, 2.56it/s]
Epoch 190, Loss: 0.0715
Training GCN: 100%|██████████████████████████████████████████████████████████████████| 200/200 [01:23<00:00, 2.40it/s]
Node features extracted using GCN.
In [271]:
# Save the trained model
torch.save(model.state_dict(), 'DTINet/data/gcn_model.pth')
# Save the extracted node features
np.save('DTINet/data/gcn_model_node_features.npy', node_features)
print("GCN model and features are saved successfully.")
GCN model and features are saved successfully.
In [ ]:
# Load the saved model state
model = GCN()
model.load_state_dict(torch.load('DTINet/data/gcn_model.pth'))
node_features = np.load('DTINet/data/gcn_model_node_features.npy')
print("GCN model and features are loaded successfully.")
In [ ]:
import random
# Positive samples: existing edges in the graph
positive_samples = []
for edge in G.edges:
positive_samples.append((list(G.nodes).index(edge[0]), list(G.nodes).index(edge[1])))
# Negative samples: node pairs with no edge between them
negative_samples = []
while len(negative_samples) < len(positive_samples):
node1 = random.choice(list(G.nodes))
node2 = random.choice(list(G.nodes))
if not G.has_edge(node1, node2):
negative_samples.append((list(G.nodes).index(node1), list(G.nodes).index(node2)))
# Labels for the samples
y_positive = [1] * len(positive_samples)
y_negative = [0] * len(negative_samples)
# Combine positive and negative samples
X_samples = positive_samples + negative_samples
y_samples = y_positive + y_negative
# Convert to numpy arrays
X_samples = np.array(X_samples)
y_samples = np.array(y_samples)
print("Positive and negative samples constructed for supervised learning.")
In [356]:
print("Node Features Shape:", node_features.shape)
print("Labels Shape:", labels.shape)
Node Features Shape: (11317, 2) Labels Shape: (1070496,)
Link prediction model - LR¶
In [350]:
from sklearn.linear_model import LogisticRegression
# Split data into train and test sets
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X_samples, y_samples, test_size=0.2, random_state=42)
# Initialize and train a logistic regression model
logreg = LogisticRegression()
logreg.fit(X_train, y_train)
Out[350]:
LogisticRegression()In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LogisticRegression()
In [357]:
# Predict Potential Links
y_pred = logreg.predict(X_test)
# Extract predicted links where the predicted label is positive and the actual label is negative
gcn_lr_predicted_links = [(pairs_test[i][0], pairs_test[i][1]) for i in range(len(y_pred)) if y_pred[i] == 1 and y_test[i] == 0]
# Print predicted new links
print("Predicted new links:")
for link in gcn_lr_predicted_links:
print(link)
--------------------------------------------------------------------------- IndexError Traceback (most recent call last) Cell In[357], line 5 2 y_pred = logreg.predict(X_test) 4 # Extract predicted links where the predicted label is positive and the actual label is negative ----> 5 gcn_lr_predicted_links = [(pairs_test[i][0], pairs_test[i][1]) for i in range(len(y_pred)) if y_pred[i] == 1 and y_test[i] == 0] 7 # Print predicted new links 8 print("Predicted new links:") Cell In[357], line 5, in <listcomp>(.0) 2 y_pred = logreg.predict(X_test) 4 # Extract predicted links where the predicted label is positive and the actual label is negative ----> 5 gcn_lr_predicted_links = [(pairs_test[i][0], pairs_test[i][1]) for i in range(len(y_pred)) if y_pred[i] == 1 and y_test[i] == 0] 7 # Print predicted new links 8 print("Predicted new links:") IndexError: list index out of range
In [277]:
# Save the model
import joblib
model_path = 'DTINet/data/gcn_model_predictionModel_LR.pkl'
joblib.dump(logreg, model_path)
print(f"Model saved to {model_path}")
Model saved to DTINet/data/gcn_model_predictionModel_LR.pkl
In [ ]:
# Load the model
import joblib
model_path = 'DTINet/data/gcn_model_predictionModel_LR.pkl'
logreg = joblib.load(model_path)
print("Model loaded successfully")
GCN_LR model evaluation¶
In [ ]:
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, precision_score, recall_score, f1_score
# Predict probabilities for the test set
y_pred_proba = logreg.predict_proba(X_test)[:, 1]
# Calculate AUROC and AUPR
auroc = roc_auc_score(y_test, y_pred_proba)
aupr = average_precision_score(y_test, y_pred_proba)
# Predict binary labels based on probability threshold
y_pred = (y_pred_proba > 0.5).astype(int)
# Calculate accuracy, precision, recall, and F1-score
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
# Print evaluation metrics
print(f"AUROC: {auroc:.4f}")
print(f"AUPR: {aupr:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-score: {f1:.4f}")
# Plot ROC curve
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='blue', lw=2, label='ROC curve (AUROC = %0.4f)' % auroc)
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.show()
# Plot precision-recall curve
precision, recall, _ = precision_recall_curve(y_test, y_pred_proba)
plt.figure(figsize=(8, 6))
plt.plot(recall, precision, color='red', lw=2, label='Precision-Recall curve (AUPR = %0.4f)' % aupr)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.legend(loc="lower left")
plt.show()
Visualization¶
In [343]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
# Create a new graph for predicted links
G_predicted = nx.Graph()
# Add predicted nodes and edges
for u, v in gcn_lr_predicted_links:
G_predicted.add_node(u, bipartite=0)
G_predicted.add_node(v, bipartite=1)
G_predicted.add_edge(u, v)
# All edges in data
all_known_edges = [(ls_drug[i], ls_protein[j]) for i, j in zip(*np.where(np_mat_drug_protein == 1))]
# Add known edges between the predicted nodes
known_edges = [(u, v) for u, v in gcn_lr_predicted_links if np_mat_drug_protein[ls_drug.index(u)][ls_protein.index(v)] == 1]
# Identify nodes involved in predicted edges
predicted_nodes = set([u for u, v in gcn_lr_predicted_links] + [v for u, v in gcn_lr_predicted_links])
# Identify all known edges connected to predicted nodes but not in predicted edges
additional_edges = [
(u, v) for u, v in all_known_edges
if (u in predicted_nodes or v in predicted_nodes) and (u, v) not in predicted_edges and (u, v) not in known_edges
]
# Ensure all nodes in additional_edges are added to the graph
for u, v in additional_edges:
if u not in G_predicted:
G_predicted.add_node(u, bipartite=0 if u in ls_drug else 1)
if v not in G_predicted:
G_predicted.add_node(v, bipartite=0 if v in ls_drug else 1)
# Draw the predicted links graph
plt.figure(figsize=(20, 20))
pos = nx.spring_layout(G_predicted, k=0.1, iterations=5) # Adjust k for node spacing
# Draw known edges in green
nx.draw_networkx_edges(G_predicted, pos, edgelist=known_edges, edge_color='green', width=1)
# Draw predicted links in red
predicted_edges = [(u, v) for u, v in gcn_lr_predicted_links if np_mat_drug_protein[ls_drug.index(u)][ls_protein.index(v)] == 0]
nx.draw_networkx_edges(G_predicted, pos, edgelist=predicted_edges, edge_color='red', width=0.35)
# Draw additional known edges in black
nx.draw_networkx_edges(G_predicted, pos, edgelist=additional_edges, edge_color='gray', width=0.35)
# Draw nodes with different colors for drugs and proteins
drug_nodes = [node for node in G_predicted.nodes if node in ls_drug]
protein_nodes = [node for node in G_predicted.nodes if node in ls_protein]
nx.draw_networkx_nodes(G_predicted, pos, nodelist=drug_nodes, node_color='lightgreen', node_size=50, label='Drugs')
nx.draw_networkx_nodes(G_predicted, pos, nodelist=protein_nodes, node_color='pink', node_size=50, label='Proteins')
# Draw labels
# nx.draw_networkx_labels(G_predicted, pos, font_size=6, font_family='sans-serif')
plt.title("Predicted Drug-Protein Interaction Network")
plt.legend(markerscale=0.35)
plt.savefig('DTINet/plots/_predicted_interaction_network.png', format='png', dpi=300)
plt.show()
In [ ]:
Combined bipartide graphs¶
In [208]:
import networkx as nx
import numpy as np
# Create bipartite graphs for each interaction type
G_drug_protein = nx.Graph()
G_drug_disease = nx.Graph()
G_drug_side_effect = nx.Graph()
# Add drug and protein nodes to G_drug_protein
for drug in ls_drug:
G_drug_protein.add_node(drug, bipartite=0)
for protein in ls_protein:
G_drug_protein.add_node(protein, bipartite=1)
# Add edges to G_drug_protein
for i, drug in enumerate(ls_drug):
for j, protein in enumerate(ls_protein):
if np_mat_drug_protein[i, j] == 1:
G_drug_protein.add_edge(drug, protein)
# Add drug and disease nodes to G_drug_disease
for drug in ls_drug:
G_drug_disease.add_node(drug, bipartite=0)
for disease in ls_disease:
G_drug_disease.add_node(disease, bipartite=1)
# Add edges to G_drug_disease
for i, drug in enumerate(ls_drug):
for j, disease in enumerate(ls_disease):
if np_mat_drug_disease[i, j] == 1:
G_drug_disease.add_edge(drug, disease)
# Add drug and side effect nodes to G_drug_side_effect
for drug in ls_drug:
G_drug_side_effect.add_node(drug, bipartite=0)
for side_effect in ls_sideEffect:
G_drug_side_effect.add_node(side_effect, bipartite=1)
# Add edges to G_drug_side_effect
for i, drug in enumerate(ls_drug):
for j, side_effect in enumerate(ls_sideEffect):
if np_mat_drug_sideEffect[i, j] == 1:
G_drug_side_effect.add_edge(drug, side_effect)
# Print the number of nodes and edges in each graph
print(f'G_drug_protein: {G_drug_protein.number_of_nodes()} nodes, {G_drug_protein.number_of_edges()} edges')
print(f'G_drug_disease: {G_drug_disease.number_of_nodes()} nodes, {G_drug_disease.number_of_edges()} edges')
print(f'G_drug_side_effect: {G_drug_side_effect.number_of_nodes()} nodes, {G_drug_side_effect.number_of_edges()} edges')
G_drug_protein: 2201 nodes, 1920 edges G_drug_disease: 6311 nodes, 199214 edges G_drug_side_effect: 4900 nodes, 80164 edges
Extracting Traditional Graph Features¶
In [282]:
import networkx as nx
import numpy as np
from tqdm import tqdm
def common_neighbors(graph, node1, node2):
return len(list(nx.common_neighbors(graph, node1, node2)))
def jaccard_coefficient(graph, node1, node2):
union_size = len(set(graph.neighbors(node1)).union(set(graph.neighbors(node2))))
if union_size == 0:
return 0
return len(list(nx.common_neighbors(graph, node1, node2))) / union_size
def adamic_adar_index(graph, node1, node2):
aa_index = sum(1 / np.log(len(list(graph.neighbors(neighbor)))) for neighbor in nx.common_neighbors(graph, node1, node2))
return aa_index
# Create a list to hold features for each node pair
traditionalFeatures = []
# Compute features for each drug-protein pair
for drug in tqdm(ls_drug, desc="Computing features for drug-protein pairs"):
for protein in ls_protein:
if (drug, protein) in G_drug_protein.edges or (protein, drug) in G_drug_protein.edges:
common_neighbors_count = common_neighbors(G_drug_protein, drug, protein)
jaccard = jaccard_coefficient(G_drug_protein, drug, protein)
adamic_adar = adamic_adar_index(G_drug_protein, drug, protein)
traditionalFeatures.append([drug, protein, common_neighbors_count, jaccard, adamic_adar])
# Convert the list of features to a numpy array
traditional_features_array = np.array(traditionalFeatures)
Computing features for drug-protein pairs: 100%|████████████████████████████████████| 708/708 [00:00<00:00, 732.83it/s]
In [283]:
# Save the features as a NumPy array
save_path = 'DTINet/data/traditional_features.npy'
np.save(save_path, traditional_features_array)
print(f"Features saved to {save_path}")
Features saved to DTINet/data/traditional_features.npy
In [328]:
max_common_neighbors = traditional_features_df["common_neighbors"].max()
max_jaccard = traditional_features_df["jaccard"].max()
max_adamic_adar = traditional_features_df["adamic_adar"].max()
print("Max Common Neighbors:", max_common_neighbors)
print("Max Jaccard Coefficient:", max_jaccard)
print("Max Adamic/Adar Index:", max_adamic_adar)
Max Common Neighbors: 0 Max Jaccard Coefficient: 0.0 Max Adamic/Adar Index: 0
Combined GCN features¶
In [292]:
from tqdm import tqdm
# Define and train your GCN model
def train_gcn(G):
# Add 'node_type' attribute to all nodes in the graph
for node in G.nodes():
G.nodes[node]['node_type'] = 'drug' # You can assign the appropriate node type here
# Convert graph to PyTorch Geometric data object
x = torch.eye(len(G.nodes())) # Identity matrix as feature matrix
# Extract edge indices
edge_index = []
for edge in G.edges():
node1, node2 = edge
edge_index.append([list(G.nodes()).index(node1), list(G.nodes()).index(node2)])
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
data = Data(x=x, edge_index=edge_index)
# Initialize the model
model = GCN()
# Training settings
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
# Training loop with progress bar
model.train()
epochs = 200
progress_bar = tqdm(range(epochs), desc="Training GCN")
for epoch in progress_bar:
optimizer.zero_grad()
out = model(data)
loss = criterion(out, torch.tensor([G.nodes[node]['node_type'] == 'drug' for node in G.nodes], dtype=torch.long))
loss.backward()
optimizer.step()
progress_bar.set_postfix({'Loss': loss.item()})
# Evaluate the model
model.eval()
with torch.no_grad():
out = model(data).numpy()
# Convert embeddings to a dictionary
embeddings = {node: embedding for node, embedding in zip(G.nodes, out)}
return embeddings
# Train the GCN and get node embeddings
G_combined = nx.compose_all([G_drug_protein, G_drug_disease, G_drug_side_effect])
embeddings = train_gcn(G_combined)
# Convert embeddings to a DataFrame
embeddings_df = pd.DataFrame.from_dict(embeddings, orient='index')
Training GCN: 100%|███████████████████████████████████████████████████| 200/200 [00:34<00:00, 5.85it/s, Loss=0.000712]
In [293]:
# Save the embedding matrix
embedding_matrix = embeddings_df.to_numpy()
save_path = 'DTINet/data/combined_gcn_features.npy'
np.save(save_path, embedding_matrix)